# graph_utils/kernels.py
import math
# The kernels can be accelerated with Numba for performance.  When Numba
# is unavailable at import time the functions fall back to pure Python
# implementations.  This ensures that the package can still be imported
# without requiring an external JIT compiler.
try:
    import numba  # type: ignore[import]
    _HAS_NUMBA = True
except Exception:
    numba = None  # type: ignore[assignment]
    _HAS_NUMBA = False

# Decorator factory: compile with numba.njit when the Numba import above
# succeeded, otherwise hand the function back untouched.
def _maybe_jit(signature=None, **jit_kwargs):  # pragma: no cover
    """
    Build a decorator that compiles its target with ``numba.njit`` when
    Numba is importable and is a no-op otherwise.

    ``signature`` is accepted for forward compatibility but is not used;
    no explicit signature is given to Numba, which infers it automatically.
    Extra keyword arguments are forwarded verbatim to ``numba.njit``.
    """
    def _apply(fn):  # pragma: no cover
        if not _HAS_NUMBA:
            # Pure-Python fallback: keep the original function object.
            return fn
        compiler = numba.njit(**jit_kwargs)
        return compiler(fn)
    return _apply
import numpy as np

# -------------------- UMAP heavy-tail default --------------------
# Parameter vector (a, b) for default_kernel and its coefficient helpers.
# a = b = 1 gives Q(s) = 1 / (1 + s).
DEFAULT_KERNEL_DEFAULT_PARAMS = np.ones(2, dtype=np.float32)  # a, b

@_maybe_jit()
def default_kernel(s: float, p: np.ndarray) -> float:
    """Heavy-tailed similarity ``Q(s) = 1 / (1 + a * s**b)`` with ``p = (a, b)``."""
    a, b = float(p[0]), float(p[1])
    denom = 1.0 + a * s ** b
    return 1.0 / denom

@_maybe_jit()
def default_attr_coeff(s: float, p: np.ndarray) -> float:
    """
    Return ``-a*b*s**(b-1) / (1 + a*s**b)`` with ``p = (a, b)``.

    This is d/ds log(default_kernel(s, p)). ``s`` is replaced by ``1e-12``
    when it is not above that value, so ``s**(b-1)`` stays finite for b < 1.
    """
    a, b = float(p[0]), float(p[1])
    s_safe = s if s > 1e-12 else 1e-12
    slope = (a * b) * (s_safe ** (b - 1.0))
    return -slope / (1.0 + a * (s_safe ** b))

@_maybe_jit()
def default_rep_coeff(s: float, p: np.ndarray) -> float:
    """
    Return ``a*b*s**(b-1) / (1 + a*s**b)**2`` with ``p = (a, b)``.

    This equals -dQ/ds for default_kernel. ``s`` is replaced by ``1e-12``
    when it is not above that value, to keep ``s**(b-1)`` finite.
    """
    a, b = float(p[0]), float(p[1])
    s_safe = s if s > 1e-12 else 1e-12
    base = 1.0 + a * (s_safe ** b)
    return ((a * b) * (s_safe ** (b - 1.0))) / (base * base)

# ---------- Built-in kernel families ----------
# Name aliases recognised by make_builtin_kernel().
#
# The kernel formulas in this module appear to take ``s`` as a *squared*
# distance d**2:
#   * UMAP's 1 / (1 + a * d**(2b)) is written as 1 / (1 + a * s**b).
#   * Student-t / Cauchy use s / nu.
#   * The exp_sqrt_* kernels take sqrt(s) to recover d.
# Under that convention:
#   * "gaussian" = exp(-lam * d**2) = exp(-lam * s)       -> exp_* kernels
#   * "laplace"  = exp(-lam * d)    = exp(-lam * sqrt(s)) -> exp_sqrt_* kernels
# Listing "laplace" under EXP_NAMES would silently turn it into a Gaussian.
# NOTE(review): confirm callers pass squared distances as ``s``.
UMAP_NAMES = ("umap",)
STUDENT_T_NAMES = ("student_t", "cauchy")
EXP_NAMES = ("gaussian", "exp")
EXP_SQRT_NAMES = ("exp_sqrt", "laplace")
POWERLAW_NAMES = ("powerlaw",)

@_maybe_jit()
def student_t_Q(s: float, p: np.ndarray) -> float:
    """Student-t similarity ``(1 + s/nu) ** (-(nu + 1)/2)``; ``nu = p[0]`` floored at 1e-12."""
    raw = p[0]
    dof = float(raw) if raw > 1e-12 else 1e-12
    exponent = -(0.5 * (dof + 1.0))
    return (1.0 + s / dof) ** exponent

@_maybe_jit()
def student_t_dlogQ(s: float, p: np.ndarray) -> float:
    """d/ds log of the Student-t kernel: ``-((nu + 1)/2) / (nu + s)``; nu floored at 1e-12."""
    raw = p[0]
    dof = float(raw) if raw > 1e-12 else 1e-12
    half_dof_plus_one = 0.5 * (dof + 1.0)
    return -half_dof_plus_one / (dof + s)

@_maybe_jit()
def student_t_negdQ(s: float, p: np.ndarray) -> float:
    """-dQ/ds of the Student-t kernel: ``(c/nu) * (1 + s/nu) ** (-c - 1)``, ``c = (nu + 1)/2``."""
    raw = p[0]
    dof = float(raw) if raw > 1e-12 else 1e-12
    c = 0.5 * (dof + 1.0)
    return (c / dof) * (1.0 + s / dof) ** (-c - 1.0)

@_maybe_jit()
def exp_Q(s: float, p: np.ndarray) -> float:
    """Exponential similarity ``exp(-lam * s)``; ``lam = p[0]`` floored at 1e-12."""
    raw = p[0]
    rate = float(raw) if raw > 1e-12 else 1e-12
    return math.exp(-(rate * s))

@_maybe_jit()
def exp_dlogQ(s: float, p: np.ndarray) -> float:
    """d/ds log exp(-lam * s) = ``-lam`` (independent of ``s``); lam floored at 1e-12."""
    if p[0] > 1e-12:
        return -float(p[0])
    return -1e-12

@_maybe_jit()
def exp_negdQ(s: float, p: np.ndarray) -> float:
    """-dQ/ds of exp(-lam * s): ``lam * exp(-lam * s)``; lam floored at 1e-12."""
    raw = p[0]
    rate = float(raw) if raw > 1e-12 else 1e-12
    decay = math.exp(-rate * s)
    return rate * decay

@_maybe_jit()
def exp_sqrt_Q(s: float, p: np.ndarray) -> float:
    """Similarity ``exp(-lam * sqrt(s + 1e-12))``; ``lam = p[0]`` floored at 1e-12."""
    raw = p[0]
    rate = float(raw) if raw > 1e-12 else 1e-12
    # The 1e-12 offset keeps sqrt (and the derivative helpers) away from s == 0.
    dist = math.sqrt(s + 1e-12)
    return math.exp(-rate * dist)

@_maybe_jit()
def exp_sqrt_dlogQ(s: float, p: np.ndarray) -> float:
    """d/ds log of exp_sqrt_Q: ``-lam / (2 * sqrt(s + 1e-12))``; lam floored at 1e-12."""
    raw = p[0]
    rate = float(raw) if raw > 1e-12 else 1e-12
    two_dist = 2.0 * math.sqrt(s + 1e-12)
    return -rate / two_dist

@_maybe_jit()
def exp_sqrt_negdQ(s: float, p: np.ndarray) -> float:
    """-dQ/ds of exp_sqrt_Q: ``lam / (2 * sqrt(s + 1e-12)) * Q(s)``; lam floored at 1e-12."""
    raw = p[0]
    rate = float(raw) if raw > 1e-12 else 1e-12
    dist = math.sqrt(s + 1e-12)
    q = math.exp(-rate * dist)
    return (rate / (2.0 * dist)) * q

@_maybe_jit()
def powerlaw_Q(s: float, p: np.ndarray) -> float:
    """Power-law similarity ``(1 + s) ** (-gamma)``; ``gamma = p[0]`` floored at 1e-12."""
    raw = p[0]
    gamma = float(raw) if raw > 1e-12 else 1e-12
    return (s + 1.0) ** -gamma

@_maybe_jit()
def powerlaw_dlogQ(s: float, p: np.ndarray) -> float:
    """d/ds log of the power-law kernel: ``-gamma / (1 + s)``; gamma floored at 1e-12."""
    raw = p[0]
    gamma = float(raw) if raw > 1e-12 else 1e-12
    return -gamma / (s + 1.0)

@_maybe_jit()
def powerlaw_negdQ(s: float, p: np.ndarray) -> float:
    """-dQ/ds of the power-law kernel: ``gamma * (1 + s) ** (-gamma - 1)``; gamma floored at 1e-12."""
    raw = p[0]
    gamma = float(raw) if raw > 1e-12 else 1e-12
    return gamma * (s + 1.0) ** (-gamma - 1.0)

def builtin_kernel_names():
    """Return a new ``set`` of every canonical (lower-case) built-in kernel name."""
    families = (UMAP_NAMES, STUDENT_T_NAMES, EXP_NAMES, EXP_SQRT_NAMES, POWERLAW_NAMES)
    return {alias for family in families for alias in family}

def make_builtin_kernel(name: str, param_dict=None):
    """
    Look up a built-in kernel family and pack its parameters.

    Parameters
    ----------
    name : str
        Family name.  It is matched after ``lower()`` and ``strip()``
        against the alias tuples (``UMAP_NAMES``, ``STUDENT_T_NAMES``,
        ``EXP_NAMES``, ``EXP_SQRT_NAMES``, ``POWERLAW_NAMES``).
    param_dict : mapping or None, optional
        Parameter overrides read with ``.get``.  Recognised keys per family:

        * ``a`` / ``b``: umap, defaults 1.0 / 1.0.
        * ``nu``: student_t (default 5.0) and cauchy (default 1.0).
        * ``lam``: exponential families, default 1.0.
        * ``gamma``: powerlaw, default 1.0.

        Unknown keys are ignored.  ``None`` (the default) uses all defaults.

    Returns
    -------
    tuple
        ``(Q, dlogQ, negdQ, params)``.  The first three are callables taking
        ``(s, p)``; ``params`` is a float32 array laid out for their ``p``
        argument.  For the umap family the middle two are
        ``default_attr_coeff`` and ``default_rep_coeff``.

    Raises
    ------
    ValueError
        If ``name`` does not match any built-in family.
    """
    # Treat a missing mapping as "no overrides" instead of failing on None.get.
    params = {} if param_dict is None else param_dict
    key = name.lower().strip()
    if key in UMAP_NAMES:
        a = float(params.get("a", 1.0))
        b = float(params.get("b", 1.0))
        return default_kernel, default_attr_coeff, default_rep_coeff, np.array([a, b], dtype=np.float32)
    if key in STUDENT_T_NAMES:
        # Cauchy is the nu = 1 member of the Student-t family.
        nu = float(params.get("nu", 1.0 if key == "cauchy" else 5.0))
        return student_t_Q, student_t_dlogQ, student_t_negdQ, np.array([nu], dtype=np.float32)
    if key in EXP_NAMES:
        lam = float(params.get("lam", 1.0))
        return exp_Q, exp_dlogQ, exp_negdQ, np.array([lam], dtype=np.float32)
    if key in EXP_SQRT_NAMES:
        lam = float(params.get("lam", 1.0))
        return exp_sqrt_Q, exp_sqrt_dlogQ, exp_sqrt_negdQ, np.array([lam], dtype=np.float32)
    if key in POWERLAW_NAMES:
        gamma = float(params.get("gamma", 1.0))
        return powerlaw_Q, powerlaw_dlogQ, powerlaw_negdQ, np.array([gamma], dtype=np.float32)
    raise ValueError(f"Unknown built-in kernel family: {name!r}")
